[Refactor] Uniform PoDAttention API with Horizontal Fusion SMs Schedule #967
Conversation
Some of the unittests failed, for example (test_block_sparse_attention[False-256-16-16-128-64-16-4])
Hi, can I ask when this is planned to be merged? I made a PR to support POD Attn in SGLang using the old API and plan to get that working with CUDA graph first.
I really like the uniform batch API that this PR presents. I ran this on an A100 and compared it with the existing FlashInfer POD-Attention implementation. On average this performed around 10-15% worse, but still better than serial execution. Performance was worse for larger prefill context lengths, while for smaller context lengths it was more comparable.
Yeah, this is more convenient. One issue I had during my PR is that I have to fill a 2D attention mask for prefill every time, instead of using the page table & indices.
Will the old API be preserved? Thanks.
@AKKamath Btw, I wonder what the reason was for using a mask instead of a page table for prefill qkv?
@yzh119 can correct me here, but I believe the mask prefill kernel (single_prefill) had better performance than the page-table prefill, because the page-table prefill had higher register usage, causing register spills.
But don't we waste a lot of space storing the 2D mask? For example, the default shape is the 2D cumulative sequence lengths (qo_lens, kv_lens), but when converting from the page table qo_indptr, kv_indptr to the mask, it is very sparse, with each qo related to only a few kv entries of its own request within the whole cumulative sequence. It can also be expensive to fill the mask.
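To make the space and fill cost concrete, here is a minimal sketch (hypothetical helper, not FlashInfer code) of expanding per-request qo_indptr/kv_indptr into one packed 2D mask: the mask covers total_qo_len x total_kv_len entries even though only the per-request blocks are ever set.

```python
import torch

def build_packed_mask(qo_indptr, kv_indptr):
    """Hypothetical helper (not FlashInfer code): expand per-request indptr
    arrays into one packed 2D mask of shape (total_qo_len, total_kv_len).
    Each query can only attend to kv entries of its own request, so the mask
    is block-diagonal and mostly False; its size and fill cost grow with the
    product of the packed sequence lengths."""
    total_qo, total_kv = int(qo_indptr[-1]), int(kv_indptr[-1])
    mask = torch.zeros(total_qo, total_kv, dtype=torch.bool)
    for i in range(len(qo_indptr) - 1):
        q0, q1 = int(qo_indptr[i]), int(qo_indptr[i + 1])
        k0, k1 = int(kv_indptr[i]), int(kv_indptr[i + 1])
        mask[q0:q1, k0:k1] = True  # causal structure inside each block omitted
    return mask

# 3 requests -> a 48 x 96 mask where only the 3 diagonal blocks are True.
qo_indptr = torch.tensor([0, 16, 32, 48])
kv_indptr = torch.tensor([0, 32, 64, 96])
print(build_packed_mask(qo_indptr, kv_indptr).float().mean())  # ~0.33
```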
Actually, I realized POD Attention is not designed to mix many prefill requests with decode requests; it mixes in only one prefill at a time, so that we can use causal masking without any custom mask.
Follow-up in #1026.
std::accumulate(qo_len_ptr_h_p.begin(), qo_len_ptr_h_p.end(), 0) +
    2 * page_size * std::accumulate(kv_len_ptr_h_p.begin(), kv_len_ptr_h_p.end(), 0);
Hi, I'm interested in implementing a persistent POD Attn and have some questions. Why don't we do qo_len_ptr_h_p[i] * kv_len_ptr_h_p[i] * 2 here, to model the quadratic compute load? Thanks.
I am using the current calculation mainly to model memory load instead of compute load. For different workloads, different heuristics can work best. It would be helpful if you benchmark and decide.
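For reference, a small sketch of the two heuristics under discussion (a Python analogue of the C++ snippet above; the quadratic variant is the alternative raised in the question, not what this PR implements):

```python
def memory_load(qo_lens, kv_lens, page_size):
    # Current heuristic (mirrors the snippet above): linear in qo/kv lengths,
    # a proxy for the bytes each request streams through memory.
    return sum(qo_lens) + 2 * page_size * sum(kv_lens)

def compute_load(qo_lens, kv_lens):
    # Alternative raised in the question: qo_len * kv_len * 2 per request,
    # a proxy for the quadratic attention compute.
    return sum(q * k * 2 for q, k in zip(qo_lens, kv_lens))

# Example: one long prefill plus two decodes. The long prefill dominates the
# compute model much more strongly than the memory model, so the two
# heuristics can lead to different work partitions; benchmarking decides.
print(memory_load([2048, 1, 1], [128, 256, 256], page_size=16))
print(compute_load([2048, 1, 1], [128, 256, 256]))
```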
BTW, #1026 will be the upstream version and this PR has been deprecated. It would be helpful if you could refer directly to the new PR.
Thanks. Do you have any plans to adapt POD Attn to the persistent template? I also plan to work on that.
Description
This PR is a follow-up to #858, which integrates the PoDAttention (arXiv link) API in a user-transparent manner. Users can now invoke PoDAttention via the same API as BatchPrefillWithPagedKVCache, without explicitly specifying whether requests are prefill or decode (example code).
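For illustration, a minimal sketch of what the uniform invocation could look like from Python, assuming the existing BatchPrefillWithPagedKVCacheWrapper plan/run interface; the shapes, head counts, and request mix below are made up, and the PR's linked example code is authoritative.

```python
import torch
import flashinfer

# Sketch: one prefill request (qo_len 512) and three decode requests (qo_len 1)
# issued through the same batch-prefill wrapper; under this PR the kernel
# decides internally which requests take the prefill vs decode path.
# Argument names/order follow my reading of the flashinfer Python API.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
qo_indptr = torch.tensor([0, 512, 513, 514, 515], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 32, 160, 288, 416], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(416, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((4,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size, causal=True)

q = torch.randn(515, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(416, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
o = wrapper.run(q, kv_cache)  # no explicit prefill/decode split needed
```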
Key Changes
Support for Non-Contiguous Q/O and KV Tensor Layout
Previously, tensor offsets were computed using indptr, assuming contiguous layouts. PoDAttention requires supporting mixed prefill/decode subsets within requests, necessitating a non-contiguous layout. This PR introduces q_lenptr and kv_lenptr to accommodate this functionality (code link).
Horizontal Fusion-Style Implementation
For improved efficiency, the prefill and decode subsets of requests are aware of each other, enabling optimal selection of kernel hyperparameters and persistent kernel execution.
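Very roughly, the scheduling idea looks like the following (purely illustrative pseudologic, not the kernel's real scheduler): instead of launching separate prefill and decode kernels, tiles from both subsets go into one work list that a fixed pool of persistent CTAs drains, so both kinds of work share the SMs within a single launch.

```python
# Purely illustrative: a fixed pool of persistent "CTAs" drains one queue that
# interleaves prefill and decode tiles, instead of two back-to-back launches.
from collections import deque

prefill_tiles = [("prefill", req, tile) for req in range(1) for tile in range(8)]
decode_tiles = [("decode", req, 0) for req in range(64)]
work_queue = deque(prefill_tiles + decode_tiles)

NUM_PERSISTENT_CTAS = 4  # stand-in for the number of SMs kept resident
assignments = {cta: [] for cta in range(NUM_PERSISTENT_CTAS)}
cta = 0
while work_queue:
    assignments[cta].append(work_queue.popleft())  # next CTA grabs the next tile
    cta = (cta + 1) % NUM_PERSISTENT_CTAS

for cta, tiles in assignments.items():
    kinds = sorted({kind for kind, _, _ in tiles})
    print(f"CTA {cta}: {len(tiles)} tiles, kinds={kinds}")
```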
Limitations and Future Work
The current prefill/decode classification heuristic (qo_len > threshold) is preliminary and requires improvement (classifier implementation).
cc @AKKamath @yzh119